"""@plugin POODLEVulnerabilityTesterPlugin
This file contains the plugin classes that test server(s) for vulnerability CVE-2014-3566 (POODLE).

@author: Bc. Pavel Soukup
"""

from operator import attrgetter
from xml.etree.ElementTree import Element

from nassl import SSLV2, SSLV3, TLSV1, TLSV1_1, TLSV1_2, SSL_MODE_SEND_FALLBACK_SCSV, _nassl
from nassl.ssl_client import SslClient
from sets import Set

from sslyze.plugins.common_new_plugin_info import AcceptCipher, RejectCipher, OPENSSL_TO_RFC_NAMES_MAPPING

from sslyze.plugins import plugin_base
from sslyze.plugins.plugin_base import PluginResult
from sslyze.utils.ssl_connection import SSLHandshakeRejected, SSLConnection
from sslyze.utils.thread_pool import ThreadPool

class POODLEVulnerabilityTesterPlugin(plugin_base.PluginBase):
    """class POODLEVulnerabilityTesterPlugin

    Subclass of the abstract class PluginBase. An instance of this class tests a server for
    vulnerability CVE-2014-3566 (POODLE) and decides whether the server is vulnerable.

    The server is reported as vulnerable only when it accepts SSL 3.0 AND accepts at least
    one CBC-mode cipher suite over SSL 3.0.
    """
    interface = plugin_base.PluginInterface(
        "POODLEVulnerabilityTesterPlugin",
        "Scans the server(s) and checks if requirements for POODLE attack are satisfied.")
    interface.add_command(
        command="poodle",
        help="Tests server(s) for CVE-2014-3566 vulnerability.")

    # Upper bound on the number of worker threads used for the SSL 3.0 cipher-suite scan.
    MAX_THREADS = 15

    def process_task(self, server_connectivity_info, plugin_command, option_dict=None):
        """Runs the POODLE test against one server and returns a POODLEVulnerabilityTesterResult.

            Args:
            server_connectivity_info (ServerConnectivityInfo): contains information for connection on server
            plugin_command (str): the command being processed (passed through to the result)
            option_dict (dict|None): plugin options; only the 'verbose' key is read here

            Raises:
            The first exception raised by a cipher-scan worker thread, if any.
            ValueError: if a worker returns something other than AcceptCipher/RejectCipher.
        """
        verbose_mode = option_dict.get('verbose', False) if option_dict else False

        ssl3_support = self.test_SSLv3_support(server_connectivity_info)
        support_vulnerable_ciphers = None
        if ssl3_support:
            accept_ciphers = self._scan_ssl3_ciphers(server_connectivity_info, verbose_mode)
            support_vulnerable_ciphers = self.get_vulnerable_ciphers(accept_ciphers)

        # Vulnerable only if SSL 3.0 is enabled AND at least one CBC suite is accepted.
        # An empty list of CBC suites means the POODLE requirements are not met.
        is_vulnerable = bool(ssl3_support and support_vulnerable_ciphers)

        return POODLEVulnerabilityTesterResult(server_connectivity_info, plugin_command, option_dict,
                                               ssl3_support, is_vulnerable, support_vulnerable_ciphers)

    def _scan_ssl3_ciphers(self, server_connectivity_info, verbose_mode):
        """Tries every SSL 3.0 cipher suite against the server in parallel; returns the list of AcceptCipher results.

            Args:
            server_connectivity_info (ServerConnectivityInfo): contains information for connection on server
            verbose_mode (bool): when true, each per-cipher result is printed as it arrives
        """
        cipher_list = self.get_ssl3_cipher_list()
        thread_pool = ThreadPool()
        for cipher in cipher_list:
            thread_pool.add_job((self._test_ciphersuite, (server_connectivity_info, cipher)))
        thread_pool.start(nb_threads=min(len(cipher_list), self.MAX_THREADS))

        accept_ciphers = []
        if verbose_mode:
            print('  VERBOSE MODE PRINT')
            print('  ------------------')
        for _job, cipher_result in thread_pool.get_result():
            if isinstance(cipher_result, AcceptCipher):
                accept_ciphers.append(cipher_result)
            elif not isinstance(cipher_result, RejectCipher):
                raise ValueError("Unexpected result")
            if verbose_mode:
                cipher_result.print_cipher()

        if verbose_mode:
            print('  ----------------------')
            print('  END VERBOSE MODE PRINT')
            print('  ----------------------')

        # Propagate the first worker failure. The original ordering (results, then errors,
        # then join) is kept deliberately; NOTE(review): ThreadPool may rely on the queues being
        # drained before join() - confirm before reordering.
        for _job, exception in thread_pool.get_error():
            raise exception

        thread_pool.join()
        return accept_ciphers

    def _test_ciphersuite(self, server_connectivity_info, cipher):
        """Used by worker threads: attempts an SSL 3.0 handshake restricted to one cipher suite.
        Returns an instance of AcceptCipher if the handshake succeeds, otherwise RejectCipher.

            Args:
            server_connectivity_info (ServerConnectivityInfo): contains information for connection on server
            cipher (str): OpenSSL name identifying the cipher suite
        """
        rfc_name = OPENSSL_TO_RFC_NAMES_MAPPING[SSLV3].get(cipher, cipher)
        ssl_conn = server_connectivity_info.get_preconfigured_ssl_connection(override_ssl_version=SSLV3)
        ssl_conn.set_cipher_list(cipher)
        try:
            ssl_conn.connect()
        except Exception as e:
            # Any handshake failure (SSLHandshakeRejected or otherwise) counts as rejection.
            cipher_result = RejectCipher(rfc_name, str(e))
        else:
            cipher_result = AcceptCipher(rfc_name)
        finally:
            ssl_conn.close()
        return cipher_result

    def get_ssl3_cipher_list(self):
        """Returns the list of cipher suites available for protocol SSL 3.0 in the local OpenSSL build.
        """
        ssl_client = SslClient(ssl_version=SSLV3)
        ssl_client.set_cipher_list('SSLv3')
        return ssl_client.get_cipher_list()

    def test_SSLv3_support(self, server_connectivity_info):
        """Tests if protocol SSL 3.0 is supported by the server. Returns True if it is, otherwise False.

            Args:
            server_connectivity_info (ServerConnectivityInfo): contains information for connection on server
        """
        # Shortcut only when SSL 3.0 is exactly the highest version already found. Any other value
        # (including SSLv2 / SSLv23) does not prove SSL 3.0 support, so do a real handshake.
        if server_connectivity_info.highest_ssl_version_supported == SSLV3:
            return True
        ssl_conn = server_connectivity_info.get_preconfigured_ssl_connection(override_ssl_version=SSLV3)
        try:
            ssl_conn.connect()
        except SSLHandshakeRejected:
            return False
        finally:
            ssl_conn.close()
        return True

    def get_vulnerable_ciphers(self, accept_ciphers):
        """Returns a list of RFC names (str) of the accepted cipher suites that use CBC mode.

            Args:
            accept_ciphers (list): AcceptCipher instances for suites supported by the server
        """
        # NOTE(review): `use_CBC_mode` is read as an attribute, as in the original - confirm it
        # is not a method on AcceptCipher (a bound method would always be truthy).
        return [cipher._cipher_rfc_name for cipher in accept_ciphers if cipher.use_CBC_mode]

class POODLEVulnerabilityTesterResult(PluginResult):
    """class POODLEVulnerabilityTesterResult

    Subclass of PluginResult. Holds the outcome of the test performed by
    POODLEVulnerabilityTesterPlugin and renders it as text or XML.
    """
    COMMAND_TITLE = 'Vulnerability CVE-2014-3566'
    CIPHER_LIST_TITLE_FORMAT = '      {section_title:<32}'.format
    CIPHER_LINE_FORMAT = u'        {cipher_name:<50}'.format
    LINE_FORMAT = u'    {technology:<40}{result:<20}'.format

    def __init__(self, server_connectivity_info, plugin_command, plugin_option, ssl3_support, is_vulnerable, support_vulnerable_ciphers = None):
        """
            Args:
            server_connectivity_info (ServerConnectivityInfo): the tested server
            plugin_command (str): the command that produced this result
            plugin_option (dict|None): options passed to the plugin
            ssl3_support (bool): whether the server accepts SSL 3.0
            is_vulnerable (bool): final POODLE verdict
            support_vulnerable_ciphers (list|None): accepted CBC cipher suite names; None when SSL 3.0 was not scanned
        """
        super(POODLEVulnerabilityTesterResult, self).__init__(server_connectivity_info, plugin_command, plugin_option)
        self.support_vulnerable_ciphers = support_vulnerable_ciphers
        self.ssl3_support = ssl3_support
        self.is_vulnerable = is_vulnerable

    def as_text(self):
        """Returns the result as a list of formatted text lines."""
        txt_output = [self.PLUGIN_TITLE_FORMAT(self.COMMAND_TITLE)]
        poodle_txt = 'VULNERABLE - server is vulnerable to POODLE attack' \
            if self.is_vulnerable \
            else 'OK - Not VULNERABLE to POODLE attack'
        txt_output.append(self.FIELD_FORMAT("", poodle_txt))
        result_ssl_txt = 'Is supported' \
            if self.ssl3_support \
            else 'Is not supported'
        txt_output.append(self.LINE_FORMAT(technology='SSLv3 protocol:', result=result_ssl_txt))
        if self.ssl3_support:
            txt_output.append(self.CIPHER_LIST_TITLE_FORMAT(section_title='Vulnerable cipher/ciphers:'))
            # Truthiness check handles both None and an empty list.
            if not self.support_vulnerable_ciphers:
                txt_output.append(self.CIPHER_LINE_FORMAT(cipher_name='None CBC cipher supported'))
            else:
                for cipher in self.support_vulnerable_ciphers:
                    txt_output.append(self.CIPHER_LINE_FORMAT(cipher_name=cipher))
        return txt_output

    def as_xml(self):
        """Returns the result as an xml.etree.ElementTree.Element."""
        xml_output = Element(self.plugin_command, title=self.COMMAND_TITLE)
        xml_output.append(Element('vulnerable', isVulnerable=str(self.is_vulnerable)))
        xml_output.append(Element('sslv3', support=str(self.ssl3_support)))
        # support_vulnerable_ciphers is None when SSL 3.0 was not scanned; len(None) would raise.
        if self.support_vulnerable_ciphers:
            ciphers_xml = Element('vulnerableCipherSuites')
            for cipher_name in self.support_vulnerable_ciphers:
                ciphers_xml.append(Element('cipherSuite', name=cipher_name))
            xml_output.append(ciphers_xml)
        return xml_output
